{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Library imports "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Basic Code is taken from https://github.com/ckmarkoh/GAN-tensorflow\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "import numpy as np\n",
    "from skimage.io import imsave\n",
    "import os\n",
    "import shutil"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Constants and flags "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "img_height = 28\n",
    "img_width = 28\n",
    "img_size = img_height * img_width\n",
    "\n",
    "to_train = True\n",
    "to_restore = False\n",
    "output_path = \"output\"\n",
    "\n",
    "max_epoch = 500\n",
    "\n",
    "h1_size = 150\n",
    "h2_size = 300\n",
    "z_size = 100\n",
    "batch_size = 256\n",
    "ngf = 128"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Function definitions "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Convolution layer \n",
    "Defines a general 2D convolution layer with batch normalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def general_conv2d(inputconv, name=\"conv2d\",\n",
    "                   o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, \n",
    "                   stddev=0.02, padding=None, \n",
    "                   do_norm=True, do_relu=True):\n",
    "    '''Defines a general 2D convolution layer with batch normalization'''\n",
    "    \n",
    "    with tf.variable_scope(name):\n",
    "        initializer = tf.truncated_normal_intializer(stddev=stddev)\n",
    "        w = tf.get_variable('w',\n",
    "                            [f_h, f_w, inputconv.get_shape(-1), o_d], \n",
    "                            initializer=initializer)\n",
    "        conv = tf.nn.conv2d(inputconv,\n",
    "                            filter=w,strides=[1, s_w, s_h, 1],\n",
    "                            padding=padding)\n",
    "        biases = tf.get_variable('b',\n",
    "                                 [o_d],\n",
    "                                 initializer=tf.constant_initializer(0.0))\n",
    "        conv = tf.nn.bias_add(conv,biases)\n",
    "        \n",
    "        # Add batch_norm layer\n",
    "        if do_norm:\n",
    "            dims = conv.get_shape()\n",
    "            scale = tf.get_variable('scale',\n",
    "                                    [dims[1],dims[2],dims[3]],\n",
    "                                    tf.constant_initializer(1))\n",
    "            beta = tf.get_variable('beta',\n",
    "                                   [dims[1],dims[2],dims[3]],\n",
    "                                   tf.constant_initializer(0))\n",
    "            conv_mean, conv_var = tf.nn.moments(conv,[0])\n",
    "            conv = tf.nn.batch_normalization(conv, conv_mean, conv_var, beta, scale, 0.001)\n",
    "        # Add ReLU activation\n",
    "        if do_relu:\n",
    "            conv = tf.nn.relu(conv,0)\n",
    "    return conv"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Resnet block "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def build_resnet_block(inputres, dim, name=\"resnet\"):\n",
    "    out_res = inputres\n",
    "    with tf.variable_scope(name):\n",
    "        out_res = general_conv2d(inputres, dim, 3, 3, 1, 1, 0.02, \"SAME\", \"c1\")\n",
    "        out_res = general_conv2d(out_res, dim, 3, 3, 1, 1, 0.02, \"SAME\", \"c2\", do_relu=False)\n",
    "    return tf.nn.relu(out_res + inputres)\n",
    "\n",
    "def build_generator_resnet_6blocks(inputgen, name=\"generator\"):\n",
    "    with tf.variable_scope(name):\n",
    "        f = 7\n",
    "        ks = 3\n",
    "        o_c1 = general_conv2d(inputgen, ngf, f, f, 1, 1, 0.02, \"SAME\", \"c1\")\n",
    "        o_c2 = general_conv2d(o_c1, ngf*2, ks, ks, 2, 2, 0.02, None, \"c2\")\n",
    "        o_c3 = general_conv2d(o_c2, ngf*4, ks, ks, 2, 2, 0.02, None, \"c3\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Show results "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def show_result(batch_res, fname, grid_size=(8, 8), grid_pad=5):\n",
    "    batch_res = 0.5 * batch_res.reshape((batch_res.shape[0], img_height, img_width)) + 0.5\n",
    "    img_h, img_w = batch_res.shape[1], batch_res.shape[2]\n",
    "    grid_h = img_h * grid_size[0] + grid_pad * (grid_size[0] - 1)\n",
    "    grid_w = img_w * grid_size[1] + grid_pad * (grid_size[1] - 1)\n",
    "    img_grid = np.zeros((grid_h, grid_w), dtype=np.uint8)\n",
    "    for i, res in enumerate(batch_res):\n",
    "        if i >= grid_size[0] * grid_size[1]:\n",
    "            break\n",
    "        img = (res) * 255\n",
    "        img = img.astype(np.uint8)\n",
    "        row = (i // grid_size[0]) * (img_h + grid_pad)\n",
    "        col = (i % grid_size[1]) * (img_w + grid_pad)\n",
    "        img_grid[row:row + img_h, col:col + img_w] = img\n",
    "    imsave(fname, img_grid)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training and testing routines "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def train():\n",
    "    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)\n",
    "\n",
    "    x_data = tf.placeholder(tf.float32, [batch_size, img_size], name=\"x_data\")\n",
    "    z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name=\"z_prior\")\n",
    "    keep_prob = tf.placeholder(tf.float32, name=\"keep_prob\")\n",
    "    global_step = tf.Variable(0, name=\"global_step\", trainable=False)\n",
    "\n",
    "    x_generated, g_params = build_generator(z_prior)\n",
    "    y_data, y_generated, d_params = build_discriminator(x_data, x_generated, keep_prob)\n",
    "\n",
    "    d_loss = - (tf.log(y_data) + tf.log(1 - y_generated))\n",
    "    g_loss = - tf.log(y_generated)\n",
    "\n",
    "    optimizer = tf.train.AdamOptimizer(0.0001)\n",
    "\n",
    "    d_trainer = optimizer.minimize(d_loss, var_list=d_params)\n",
    "    g_trainer = optimizer.minimize(g_loss, var_list=g_params)\n",
    "\n",
    "    init = tf.initialize_all_variables()\n",
    "\n",
    "    saver = tf.train.Saver()\n",
    "\n",
    "    sess = tf.Session()\n",
    "\n",
    "    sess.run(init)\n",
    "\n",
    "    if to_restore:\n",
    "        chkpt_fname = tf.train.latest_checkpoint(output_path)\n",
    "        saver.restore(sess, chkpt_fname)\n",
    "    else:\n",
    "        if os.path.exists(output_path):\n",
    "            shutil.rmtree(output_path)\n",
    "        os.mkdir(output_path)\n",
    "\n",
    "\n",
    "    z_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)\n",
    "\n",
    "    for i in range(sess.run(global_step), max_epoch):\n",
    "        for j in range(60000 / batch_size):\n",
    "            print(\"epoch:%s, iter:%s\" % (i, j))\n",
    "            x_value, _ = mnist.train.next_batch(batch_size)\n",
    "            x_value = 2 * x_value.astype(np.float32) - 1\n",
    "            z_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)\n",
    "            sess.run(d_trainer,\n",
    "                     feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.sum(0.7).astype(np.float32)})\n",
    "            if j % 1 == 0:\n",
    "                sess.run(g_trainer,\n",
    "                         feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.sum(0.7).astype(np.float32)})\n",
    "        x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_sample_val})\n",
    "        show_result(x_gen_val, \"output/sample{0}.jpg\".format(i))\n",
    "        z_random_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)\n",
    "        x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_random_sample_val})\n",
    "        show_result(x_gen_val, \"output/random_sample{0}.jpg\".format(i))\n",
    "        sess.run(tf.assign(global_step, i + 1))\n",
    "        saver.save(sess, os.path.join(output_path, \"model\"), global_step=global_step)\n",
    "\n",
    "\n",
    "def test():\n",
    "    z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name=\"z_prior\")\n",
    "    x_generated, _ = build_generator(z_prior)\n",
    "    chkpt_fname = tf.train.latest_checkpoint(output_path)\n",
    "\n",
    "    init = tf.initialize_all_variables()\n",
    "    sess = tf.Session()\n",
    "    saver = tf.train.Saver()\n",
    "    sess.run(init)\n",
    "    saver.restore(sess, chkpt_fname)\n",
    "    z_test_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)\n",
    "    x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_test_value})\n",
    "    show_result(x_gen_val, \"output/test_result.jpg\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Main/driver"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "if __name__ == '__main__':\n",
    "    if to_train:\n",
    "        train()\n",
    "    else:\n",
    "        test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
